"""
Phase 2: Baseline Rate and Inflation Regressions
Demographics → interest rates (levels) and inflation.
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd

# ── paths ────────────────────────────────────────────────────────────────────
# Absolute project root; every input and output path below is built from it.
PROJECT = Path("/mnt/c/demographics_capital_flows")
# Put the multilateral source directory first on the import path so the
# `model` module imported on the next line is looked up there first.
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))
from model import PanelGLS

# CSV panel read by main().
DATA_PATH = PROJECT / "monetary" / "data" / "processed" / "monetary_panel.csv"
# Directory the markdown tables are written to.
TABLE_DIR = PROJECT / "monetary" / "output" / "tables"
# NOTE: executes at import time, so merely importing this module creates the
# directory (and any missing parents).
TABLE_DIR.mkdir(parents=True, exist_ok=True)

# 38 three-letter country codes; main() matches them against the panel's
# `iso3` column to build the OECD subsample.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]


# ── helper ───────────────────────────────────────────────────────────────────
def run_model(panel, dep_var, rhs_vars):
    """Fit PanelGLS of ``dep_var`` on ``rhs_vars`` and summarise the fit.

    Args:
        panel: DataFrame expected to hold ``dep_var``, every name in
            ``rhs_vars`` and the ``iso3`` / ``year`` identifier columns.
        dep_var: Name of the dependent-variable column.
        rhs_vars: Sequence of regressor column names, in the order their
            coefficients should be reported.

    Returns:
        A dict with keys ``coefs``, ``se``, ``pvals`` (each mapping regressor
        name to a value), ``r2``, ``nobs`` and ``ncountries``; or ``None``
        when a required column is missing, fewer than 50 complete rows
        remain after dropping missing values, or the estimator raises.
        Every ``None`` outcome is reported on stdout.
    """
    rhs = list(rhs_vars)
    cols = [dep_var, *rhs, "iso3", "year"]

    # A missing column would otherwise raise KeyError before the try block
    # and abort the whole batch; report it like any other skipped model.
    missing = [c for c in cols if c not in panel.columns]
    if missing:
        print(f"  SKIP {dep_var}: missing columns {missing}")
        return None

    df = panel[cols].dropna()
    if len(df) < 50:
        print(f"  SKIP {dep_var}: only {len(df)} obs after dropna")
        return None

    # Deliberately broad: one failing specification should not stop the
    # remaining regressions, so any estimator error is printed and skipped.
    try:
        y = df[dep_var].values
        X = df[rhs].values
        gls = PanelGLS()
        gls.fit(y, X, df["iso3"].values, df["year"].values)
        # zip pairs names with estimates positionally.
        # NOTE(review): assumes PanelGLS returns exactly one estimate per
        # column of X, in column order (no prepended intercept) — confirm.
        return {
            "coefs": dict(zip(rhs, gls.beta)),
            "se": dict(zip(rhs, gls.se)),
            "pvals": dict(zip(rhs, gls.pvalues)),
            "r2": gls.r_squared,
            "nobs": gls.n_obs,
            "ncountries": gls.n_countries,
        }
    except Exception as e:
        print(f"  ERROR {dep_var}: {e}")
        return None


def star(p):
    """Map a p-value to its significance marker.

    Returns "***" below 0.01, "**" below 0.05, "*" below 0.10, and an empty
    string for anything else (including values that compare false, e.g. NaN).
    """
    # Thresholds are checked from strictest to loosest; first match wins.
    for cutoff, marker in ((0.01, "***"), (0.05, "**"), (0.10, "*")):
        if p < cutoff:
            return marker
    return ""


def fmt_coef(res, var):
    """Render one table cell as 'coef<stars> (se)' with three decimals.

    Args:
        res: Result dict with ``coefs``, ``se`` and ``pvals`` mappings, or
            ``None`` for a model that produced no result.
        var: Regressor name to look up.

    Returns:
        The formatted cell, or an empty string when ``res`` is ``None`` or
        has no coefficient for ``var``.
    """
    if res is None:
        return ""
    coefs = res["coefs"]
    if var not in coefs:
        return ""
    estimate = coefs[var]
    std_err = res["se"][var]
    p_value = res["pvals"][var]
    return f"{estimate:.3f}{star(p_value)} ({std_err:.3f})"


def build_markdown_table(title, row_vars, col_labels, results, footer_rows):
    """Assemble a pipe-delimited markdown table and return it as one string.

    Args:
        title: Text for the level-1 heading above the table.
        row_vars: Variable names, one body row each.
        col_labels: Column headings placed after the leading "Variable" column.
        results: One result (dict or ``None``) per column, passed to
            ``fmt_coef`` for every body row.
        footer_rows: Iterable of ``(label, cells)`` pairs rendered below a
            second separator line; ``cells`` are already-formatted strings.

    Returns:
        The heading, a blank line, header, separator, body rows, separator
        and footer rows joined by newlines.
    """
    def render_row(first_cell, cells):
        # One pipe-format line: leading label cell followed by the others.
        return f"| {first_cell} | " + " | ".join(cells) + " |"

    header_line = render_row("Variable", col_labels)
    rule = "|" + "---|" * (len(col_labels) + 1)

    # The trailing "\n" on the heading yields a blank line after joining.
    out = [f"# {title}\n", header_line, rule]
    out.extend(
        render_row(name, [fmt_coef(res, name) for res in results])
        for name in row_vars
    )
    out.append(rule)
    out.extend(render_row(label, cells) for label, cells in footer_rows)
    return "\n".join(out)


# ── main ─────────────────────────────────────────────────────────────────────
def _footer_rows(results):
    """Return the R² / N obs / N countries footer rows for a list of results.

    Entries that are ``None`` (skipped or failed models) get an empty cell.
    """
    return [
        ("R²", [f"{r['r2']:.3f}" if r else "" for r in results]),
        ("N obs", [f"{r['nobs']:,}" if r else "" for r in results]),
        ("N countries", [f"{r['ncountries']}" if r else "" for r in results]),
    ]


def main():
    """Estimate the Phase 2 regressions, then print and save both tables.

    Reads the panel from ``DATA_PATH`` and writes two markdown files into
    ``TABLE_DIR``:

    * Table 1 — each rate variable on the Z factors plus controls, for the
      full panel and the OECD subsample.
    * Table 2 — inflation on the Z factors plus controls, for the full,
      OECD, pre-GFC (year <= 2007) and post-GFC (year >= 2008) samples.
    """
    print("=" * 70)
    print("PHASE 2: BASELINE RATE & INFLATION REGRESSIONS")
    print("=" * 70)

    panel = pd.read_csv(DATA_PATH)
    print(f"Panel: {len(panel):,} obs, {panel['iso3'].nunique()} countries")

    oecd = panel[panel["iso3"].isin(OECD_38)].copy()
    print(f"OECD:  {len(oecd):,} obs, {oecd['iso3'].nunique()} countries")

    # ── TABLE 1: Rate Levels ─────────────────────────────────────────────
    print("\n── Table 1: Rate Levels ──")
    rate_dvs = ["real_bond_10y", "real_short_3m", "real_policy_rate", "term_spread"]
    z_vars = ["Z_1", "Z_2", "Z_3"]
    controls_rate = ["rgdp_growth", "inflation", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]
    rhs_rate = z_vars + controls_rate

    results_t1 = []
    col_labels_t1 = []
    # Tracks whether any column was collapsed to a single "OECD*" column, so
    # the explanatory footnote is only emitted when it actually applies.
    any_collapsed = False
    for dv in rate_dvs:
        res_full = run_model(panel, dv, rhs_rate)
        res_oecd = run_model(oecd, dv, rhs_rate)
        # Equal, non-zero observation counts are taken to mean the full
        # sample contains no usable non-OECD rows for this variable.
        full_n = res_full['nobs'] if res_full else 0
        oecd_n = res_oecd['nobs'] if res_oecd else 0
        if full_n == oecd_n and full_n > 0:
            # Data is OECD-only; report once with honest label
            any_collapsed = True
            results_t1.append(res_full)
            col_labels_t1.append(f"{dv} (OECD*)")
            print(f"  {dv}: N={full_n} (effectively OECD-only)")
        else:
            results_t1.append(res_full)
            col_labels_t1.append(f"{dv} (Full)")
            results_t1.append(res_oecd)
            col_labels_t1.append(f"{dv} (OECD)")
            print(f"  {dv} Full: N={full_n}, OECD: N={oecd_n}")

    md_t1 = build_markdown_table(
        "Table 1: Demographics and Interest Rate Levels",
        rhs_rate, col_labels_t1, results_t1, _footer_rows(results_t1),
    )
    if any_collapsed:
        md_t1 += "\n\n*OECD\\* indicates bond yield data is available only for OECD economies; full-sample and OECD results are identical, so a single column is reported.*"
    print("\n" + md_t1)
    t1_path = TABLE_DIR / "phase2_table1_rate_levels.md"
    # Explicit UTF-8: the table contains non-ASCII text ("R²").
    t1_path.write_text(md_t1, encoding="utf-8")
    print(f"\nSaved: {t1_path}")

    # ── TABLE 2: Inflation ───────────────────────────────────────────────
    print("\n── Table 2: Inflation ──")
    controls_infl = ["rgdp_growth", "output_gap", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]
    rhs_infl = z_vars + controls_infl

    pre_gfc = panel[panel["year"] <= 2007].copy()
    post_gfc = panel[panel["year"] >= 2008].copy()

    samples_t2 = [
        ("Full", panel),
        ("OECD", oecd),
        ("Pre-GFC", pre_gfc),
        ("Post-GFC", post_gfc),
    ]

    results_t2 = []
    col_labels_t2 = []
    for label, samp in samples_t2:
        res = run_model(samp, "inflation", rhs_infl)
        results_t2.append(res)
        col_labels_t2.append(label)
        print(f"  inflation {label}: N={res['nobs'] if res else 'NA'}")

    md_t2 = build_markdown_table(
        "Table 2: Demographics and Inflation",
        rhs_infl, col_labels_t2, results_t2, _footer_rows(results_t2),
    )
    print("\n" + md_t2)
    t2_path = TABLE_DIR / "phase2_table2_inflation.md"
    t2_path.write_text(md_t2, encoding="utf-8")
    print(f"\nSaved: {t2_path}")

    print("\nPhase 2 complete.")


# Run the pipeline only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
